package common;

import java.io.File;
import java.io.FileFilter;
import java.io.UnsupportedEncodingException;
import java.lang.annotation.Annotation;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

import jodd.util.StringUtil;
import org.apache.commons.lang.StringUtils;

/**
 * Reflection helpers for discovering classes on the class path.
 * <p>
 * Every scanning method resolves the package through {@link #getDefaultClassLoader()} and handles two
 * kinds of class-path roots: plain directories ({@code file:} URLs) and jar files ({@code jar:} URLs).
 * Roots reached through any other protocol are skipped. Each discovered class is loaded <em>and
 * initialized</em> through that same class loader. Any failure is rethrown wrapped in a
 * {@link RuntimeException}.
 */
public class ClassUtils
{
    /** Percent-escape for a space as it appears in {@link URL#getPath()}. */
    static final String EMPTY_BLANK_STR_CONVERT = "%20";

    /** The literal space that {@link #EMPTY_BLANK_STR_CONVERT} stands for. */
    static final String EMPTY_BLANK_STR = " ";

    /** File-name suffix of compiled class files. */
    private static final String CLASS_SUFFIX = ".class";

    /** Module descriptor file; it is not a loadable class, so it is never passed to Class.forName. */
    private static final String MODULE_INFO = "module-info" + CLASS_SUFFIX;

    /** Jar metadata directory (manifest, multi-release copies); its entries are not loadable by their path. */
    private static final String META_INF = "META-INF/";

    /** Accepts sub-directories and regular files whose name ends in ".class". */
    private static final FileFilter CLASS_FILE_OR_DIRECTORY = new FileFilter()
    {
        public boolean accept(File file)
        {
            return (file.isFile() && file.getName().endsWith(CLASS_SUFFIX)) || file.isDirectory();
        }
    };

    /** Decides whether a loaded class belongs in a scan result. */
    private interface ClassFilter
    {
        boolean accept(Class<?> cls);
    }

    /** Filter that keeps every class. */
    private static final ClassFilter ACCEPT_ALL = new ClassFilter()
    {
        public boolean accept(Class<?> cls)
        {
            return true;
        }
    };

    /**
     * Returns all classes in the given package.
     *
     * @param packageName dot-separated package name, e.g. {@code "com.example.model"}
     * @param isRecursive {@code true} to also include classes from sub-packages
     * @return the classes found, in discovery order (never {@code null})
     * @throws RuntimeException wrapping any I/O or class-loading failure
     */
    public static List<Class<?>> getClassList(String packageName, boolean isRecursive)
    {
        return scan(packageName, isRecursive, ACCEPT_ALL);
    }

    /**
     * Returns all classes in the given package and its sub-packages that carry {@code annotationClass}.
     * Only annotations retained at runtime are visible to this check.
     *
     * @param packageName dot-separated package name
     * @param annotationClass the annotation to look for
     * @return the matching classes, in discovery order (never {@code null})
     * @throws RuntimeException wrapping any I/O or class-loading failure
     */
    public static List<Class<?>> getClassListByAnnotation(String packageName,
                    final Class<? extends Annotation> annotationClass)
    {
        return scan(packageName, true, new ClassFilter()
        {
            public boolean accept(Class<?> cls)
            {
                return cls.isAnnotationPresent(annotationClass);
            }
        });
    }

    /**
     * Returns all proper subclasses of {@code superClass} in the given package and its sub-packages.
     * {@code superClass} itself is excluded.
     *
     * @param packageName dot-separated package name
     * @param superClass the required supertype
     * @return the matching classes, in discovery order (never {@code null})
     * @throws RuntimeException wrapping any I/O or class-loading failure
     */
    public static List<Class<?>> getClassListBySuper(String packageName, Class<?> superClass)
    {
        return scan(packageName, true, strictSubtypeFilter(superClass));
    }

    /**
     * Returns all implementations (classes or sub-interfaces) of {@code interfaceClass} in the given
     * package and its sub-packages. {@code interfaceClass} itself is excluded.
     *
     * @param packageName dot-separated package name
     * @param interfaceClass the required interface
     * @return the matching classes, in discovery order (never {@code null})
     * @throws RuntimeException wrapping any I/O or class-loading failure
     */
    public static List<Class<?>> getClassListByInterface(String packageName, Class<?> interfaceClass)
    {
        return scan(packageName, true, strictSubtypeFilter(interfaceClass));
    }

    /** Builds a filter that keeps types assignable to {@code parentType}, excluding {@code parentType} itself. */
    private static ClassFilter strictSubtypeFilter(final Class<?> parentType)
    {
        return new ClassFilter()
        {
            public boolean accept(Class<?> cls)
            {
                return parentType.isAssignableFrom(cls) && !parentType.equals(cls);
            }
        };
    }

    /**
     * Visits every class-path root that contains {@code packageName} and collects the classes that pass
     * {@code filter}.
     */
    private static List<Class<?>> scan(String packageName, boolean isRecursive, ClassFilter filter)
    {
        List<Class<?>> classList = new ArrayList<Class<?>>();
        try
        {
            // Resolve resources and load classes through the same loader so both views of the class path agree.
            ClassLoader loader = getDefaultClassLoader();
            String packagePath = packageName.replace('.', '/');
            Enumeration<URL> urls = loader.getResources(packagePath);
            while (urls.hasMoreElements())
            {
                URL url = urls.nextElement();
                if (url == null)
                {
                    continue;
                }
                String protocol = url.getProtocol();
                if ("file".equals(protocol))
                {
                    File directory = new File(decodeUrlPath(url.getPath()));
                    addFromDirectory(classList, directory, packageName, isRecursive, filter, loader);
                }
                else if ("jar".equals(protocol))
                {
                    // The JarFile may be shared via the URLConnection cache (useCaches defaults to true),
                    // so it is deliberately not closed here.
                    JarFile jarFile = ((JarURLConnection) url.openConnection()).getJarFile();
                    addFromJar(classList, jarFile, packagePath, isRecursive, filter, loader);
                }
            }
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }
        return classList;
    }

    /**
     * Adds the accepted classes found in {@code directory}, which holds the class files of
     * {@code packageName}. Sub-directories are treated as sub-packages and are visited only when
     * {@code isRecursive} is set.
     */
    private static void addFromDirectory(List<Class<?>> classList, File directory, String packageName,
                    boolean isRecursive, ClassFilter filter, ClassLoader loader) throws ClassNotFoundException
    {
        File[] children = directory.listFiles(CLASS_FILE_OR_DIRECTORY);
        if (children == null)
        {
            return; // not a directory, or it could not be read
        }
        for (File child : children)
        {
            String name = child.getName();
            if (child.isFile())
            {
                if (!MODULE_INFO.equals(name))
                {
                    String baseName = name.substring(0, name.length() - CLASS_SUFFIX.length());
                    addIfAccepted(classList, qualify(packageName, baseName), filter, loader);
                }
            }
            else if (isRecursive)
            {
                addFromDirectory(classList, child, qualify(packageName, name), true, filter, loader);
            }
        }
    }

    /**
     * Adds the accepted classes stored in {@code jarFile} under {@code packagePath} (slash-separated).
     * Entries outside that package directory are ignored.
     */
    private static void addFromJar(List<Class<?>> classList, JarFile jarFile, String packagePath,
                    boolean isRecursive, ClassFilter filter, ClassLoader loader) throws ClassNotFoundException
    {
        // The trailing '/' stops "com/foo" from also matching "com/foobar"; the default package has no prefix.
        String prefix = packagePath.length() == 0 ? "" : packagePath + "/";
        Enumeration<JarEntry> entries = jarFile.entries();
        while (entries.hasMoreElements())
        {
            String entryName = entries.nextElement().getName();
            if (!entryName.startsWith(prefix) || !entryName.endsWith(CLASS_SUFFIX)
                            || entryName.startsWith(META_INF) || MODULE_INFO.equals(entryName))
            {
                continue;
            }
            if (!isRecursive && entryName.indexOf('/', prefix.length()) >= 0)
            {
                continue; // belongs to a sub-package
            }
            String className =
                            entryName.substring(0, entryName.length() - CLASS_SUFFIX.length()).replace('/', '.');
            addIfAccepted(classList, className, filter, loader);
        }
    }

    /** Loads (and initializes) {@code className} through {@code loader} and keeps it if {@code filter} accepts it. */
    private static void addIfAccepted(List<Class<?>> classList, String className, ClassFilter filter,
                    ClassLoader loader) throws ClassNotFoundException
    {
        Class<?> cls = Class.forName(className, true, loader);
        if (filter.accept(cls))
        {
            classList.add(cls);
        }
    }

    /** Joins a package name and a simple name with '.'; the default package ("") adds no prefix. */
    private static String qualify(String packageName, String name)
    {
        return packageName.length() == 0 ? name : packageName + "." + name;
    }

    /**
     * Turns the percent-encoded path of a {@code file:} URL into a file-system path, decoding every
     * escape (spaces, non-ASCII characters, ...) rather than only {@code %20}.
     */
    private static String decodeUrlPath(String path) throws UnsupportedEncodingException
    {
        try
        {
            // URLDecoder implements form decoding, where '+' means space; protect literal '+' first.
            return URLDecoder.decode(path.replace("+", "%2B"), "UTF-8");
        }
        catch (IllegalArgumentException e)
        {
            // Malformed escape (e.g. a raw '%' in a hand-built URL): fall back to the legacy space-only fix.
            return path.replace(EMPTY_BLANK_STR_CONVERT, EMPTY_BLANK_STR);
        }
    }

    /**
     * Returns the path of the class-path root resource ({@code getResource("")}) as seen by
     * {@link #getDefaultClassLoader()}, or {@code ""} if there is no such resource.
     * <p>
     * The value is the raw {@link URL#getPath()}, so it may still contain percent-escapes such as {@code %20}.
     */
    public static String getClassPath()
    {
        URL resource = getDefaultClassLoader().getResource("");
        return resource == null ? "" : resource.getPath();
    }

    /**
     * Returns the class loader used by this class: the current thread's context class loader or, when that
     * is {@code null} or cannot be accessed, the loader that defined {@code ClassUtils}.
     */
    public static ClassLoader getDefaultClassLoader()
    {
        ClassLoader cl = null;
        try
        {
            cl = Thread.currentThread().getContextClassLoader();
        }
        catch (Throwable ex)
        {
            // Cannot access the thread context ClassLoader; fall through to this class's own loader below.
        }
        if (cl == null)
        {
            cl = ClassUtils.class.getClassLoader();
        }
        return cl;
    }
}
